"""
Combined 2x2 plot of Incoherent Internal Tide Total Variance and Fraction.
Row 1: Total Variance (Semidiurnal, Diurnal)
Row 2: Fraction (Semidiurnal, Diurnal)

B. Yadidya
Feb 24, 2026
"""

import sys
sys.path.append('.')
from common_imports import *

# Input file with the incoherent internal-tide variance (incoh_var_*) and
# fraction (incoh_frac_*) fields from HYCOM.
# Named ``data_dir`` rather than ``dir`` so the built-in ``dir()`` is not shadowed.
data_dir = "data/"
ds = xr.open_dataset(data_dir + 'fraction_temporal_variance_incoherent_ITs_hycom.nc')

# --- 3. Calculate Figure Size Based on Journal Guidelines ---
# Target size is given in centimetres (18.4 cm wide, at most 7.8 cm tall);
# matplotlib's ``figsize`` expects inches.
cm_per_inch = 2.54
fig_width_cm, fig_height_max_cm = 18.4, 7.8
fig_width_in, fig_height_in = (
    fig_width_cm / cm_per_inch,
    fig_height_max_cm / cm_per_inch,
)

# 2x2 panel grid, flattened so panels are addressed as axs[0]..axs[3]
# in row-major order (top-left, top-right, bottom-left, bottom-right).
fig, axs = plt.subplots(
    nrows=2,
    ncols=2,
    figsize=(fig_width_in, fig_height_in),
    constrained_layout=True,
)
axs = axs.ravel()

# Colormaps: cmap2 for the fraction row, cmap1 for the variance row.
cmap2 = 'plasma'
cmap1 = cmp.WhiteBlueGreenYellowRed

# --- 5. Define Plot Settings ---
# One dict per panel, in the same row-major order as ``axs`` (A B / C D).
# Keys: data (DataArray to contour), title (column title, top row only),
# label (panel letter), levels (contour boundaries), cmap, unit/ticks
# (annotations for the row's colorbar), extend (contourf handling of values
# outside ``levels``).
#
# ``extend`` is set explicitly and identically for both panels of a row.
# If it is omitted, xarray infers it per panel from that panel's own data
# range. The number of colours xarray samples from the colormap depends on
# ``extend``. The two panels of a row could therefore get different colours
# for the same level while sharing one colorbar drawn from the left panel.
# Passing ``extend`` to ``fig.colorbar`` instead has no effect, because a
# contour set supplies its own ``extend`` to the colorbar.
variance_levels = [0, .15, .25, .5, .75, 1, 2, 4]
fraction_levels = np.linspace(0, 1, 11)

plot_configs = [
    # Row 1: Total Variance. With extend='max', values above the top level
    # (4 cm^2) are coloured instead of being left blank.
    {
        "data": ds['incoh_var_sd'],
        "title": 'Semidiurnal Non-phase-locked Internal Tide',
        "label": "A",
        "levels": variance_levels,
        "cmap": cmap1,
        "unit": 'cm$^2$',
        "ticks": None,
        "extend": 'max',
    },
    {
        "data": ds['incoh_var_dr'],
        "title": 'Diurnal Non-phase-locked Internal Tide',
        "label": "B",
        "levels": variance_levels,
        "cmap": cmap1,
        "unit": 'cm$^2$',
        "ticks": None,
        "extend": 'max',
    },
    # Row 2: Fraction. No extension, matching the intended plain colorbar.
    # NOTE(review): this assumes fractions lie in [0, 1]. With 'neither',
    # contourf leaves values outside the levels uncoloured; confirm against
    # the data.
    {
        "data": ds['incoh_frac_sd'],
        "title": 'Semidiurnal Fraction',
        "label": "C",
        "levels": fraction_levels,
        "cmap": cmap2,
        "unit": '',
        "ticks": [0, 0.5, 1],
        "extend": 'neither',
    },
    {
        "data": ds['incoh_frac_dr'],
        "title": 'Diurnal Fraction',
        "label": "D",
        "levels": fraction_levels,
        "cmap": cmap2,
        "unit": '',
        "ticks": [0, 0.5, 1],
        "extend": 'neither',
    },
]

# --- 6. Loop Through Subplots ---
# Draw each panel. Keep the returned contour sets so the shared per-row
# colorbars can be built from them afterwards.
mappables = []
for i, (ax, config) in enumerate(zip(axs, plot_configs)):
    p = config["data"].plot.contourf(
        ax=ax,
        levels=config["levels"],
        extend=config["extend"],
        x='longitude',
        cmap=config["cmap"],
        add_colorbar=False,  # one shared colorbar per row is added later
    )
    mappables.append(p)

    # Common map decoration via the project helper. Then apply equal aspect,
    # no axis labels, and a y range (presumably latitude) of -60..60.
    plot_global_map_on_axes_basemap(ax)
    ax.set_aspect('equal')
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_ylim(-60, 60)

    # Column titles on the top row only; clear any title already set on the
    # bottom-row axes by the plotting call.
    if i < 2:
        ax.set_title(config["title"], fontsize=9)
    else:
        ax.set_title('')

# --- 7. Final Formatting ---
# Style for the bold panel letters. They sit on an opaque white rounded box
# so they stay readable on top of the filled contours.
panel_label_props = {
    'fontsize': 9,
    'fontweight': 'bold',
    'ha': 'center',
    'va': 'center',
    'bbox': {
        'boxstyle': 'round, pad=0.1',
        'facecolor': 'white',
        'edgecolor': 'none',
        'alpha': 1,
    },
}

# Panel letters ('A'..'D') near the top-left corner of each axes, placed in
# axes-fraction coordinates.
for ax, config in zip(axs, plot_configs):
    ax.text(0.021, 0.92, config["label"], transform=ax.transAxes, **panel_label_props)

# Italic, vertical row captions just left of the first column of each row.
label_kwargs = dict(fontsize=9, ha='right', va='center', rotation=90, fontstyle='italic')
for row_ax, row_caption in ((axs[0], 'Temporal Variance'), (axs[2], 'Fraction')):
    row_ax.text(-0.01, 0.5, row_caption, transform=row_ax.transAxes, **label_kwargs)

# --- Add Shared Colorbars ---
# One horizontal colorbar under each row, built from the row's left panel.
# Both panels of a row use the same levels and colormap.
# ``extend`` is deliberately not passed here. When the mappable is a
# ContourSet, the colorbar takes its extension from the contour set, so an
# ``extend`` keyword would be ignored (``False`` is also not a valid value).
cbar_kwargs = dict(location='bottom', shrink=0.35, pad=0.03, aspect=50)

# Top Row Colorbar (Variance) - Associated with axs[0] and axs[1]
cbar1 = fig.colorbar(mappables[0], ax=axs[:2], **cbar_kwargs)
# Unit text just past the right end of the bar, taken from the panel config.
cbar1.ax.text(1.06, 0, plot_configs[0]["unit"], transform=cbar1.ax.transAxes,
              va='center', ha='left', fontsize=8)
cbar1.ax.tick_params(which='major', labelsize=7, size=0, width=0)
cbar1.ax.minorticks_off()

# Bottom Row Colorbar (Fraction) - Associated with axs[2] and axs[3]
cbar2 = fig.colorbar(mappables[2], ax=axs[2:], **cbar_kwargs)
cbar2.set_ticks(plot_configs[2]["ticks"])
cbar2.ax.tick_params(which='major', labelsize=7, size=0, width=0)
cbar2.ax.minorticks_off()

# --- 8. Save the Figure ---
# 500-dpi PNG with the bounding box trimmed to the drawn content.
fig.savefig('Fig5_incoherent_IT_variance_fraction.png', dpi=500, bbox_inches='tight')

# Nothing after this point needs the figure or the dataset. Release the
# figure and close the netCDF file opened for ``ds``.
plt.close(fig)
ds.close()
